import csv
import numpy as np

# Every result array shares this shape. Per the loading loop's comments the axes
# are [epsilon, learning rate, number of pgd steps, seed].
dims = (4, 1, 5, 5)

# Independent zero-filled float arrays (one np.zeros call each, so no aliasing)
# for models trained with PGD ...
(PGD_tr_loss, PGD_tr_sa, PGD_tr_ra,
 PGD_test_loss, PGD_test_sa, PGD_test_ra) = [np.zeros(dims) for _ in range(6)]

# ... and for models trained with NPGD.
(NPGD_tr_loss, NPGD_tr_sa, NPGD_tr_ra,
 NPGD_test_loss, NPGD_test_sa, NPGD_test_ra) = [np.zeros(dims) for _ in range(6)]

def _read_final_value(path):
    """Return the first field of the last row of the CSV file at *path*, as a float.

    The file is opened with ``newline=''`` (as the csv module expects) and is
    always closed before returning. Only the last row is kept in memory.

    Raises:
        ValueError: if the file contains no rows.
        IndexError: if the last row has no fields (e.g. a trailing blank line).
    """
    last_row = None
    with open(path, newline='') as handle:
        # Iterate to the end; after the loop last_row holds the final row read.
        for last_row in csv.reader(handle):
            pass
    if last_row is None:
        raise ValueError('no rows in %s' % path)
    return float(last_row[0])


for ii, i in enumerate([2, 4, 8, 16]): # epsilon, size of attack
    for jj, j in enumerate(['1']): # learning rate fixed
        for kk, k in enumerate([2, 4, 8, 16, 32]): # how many steps of pgd
            for ll, l in enumerate([1, 2, 3, 4, 5]): # seed number
                # The learning rate j is substituted into the path instead of a
                # hard-coded 'lr1'; for j == '1' the paths are unchanged.
                PGD_test_ra[ii, jj, kk, ll] = _read_final_value(
                    'evaluation/Evaluation_CIFAR10_PGD_attack_CIFAR10_PGD_pgd%d_eps%d_lr%s_seed%d_best.csv' % (k, i, j, l))
                NPGD_test_ra[ii, jj, kk, ll] = _read_final_value(
                    'evaluation/Evaluation_CIFAR10_PGD_attack_CIFAR10_NPGD_npgd%d_eps%d_lr%s_seed%d_best.csv' % (k, i, j, l))

                
# Reduce over axis 3 (seeds). The right-hand side is evaluated before any
# rebinding, so the std is taken over the full 4-D array, and the test-accuracy
# names are then replaced by their means over that axis.
PGD_std_ra, PGD_test_ra = np.std(PGD_test_ra, axis=3), np.mean(PGD_test_ra, axis=3)
NPGD_std_ra, NPGD_test_ra = np.std(NPGD_test_ra, axis=3), np.mean(NPGD_test_ra, axis=3)


# For each epsilon (first axis of the seed-averaged arrays), print a heading
# followed by the PGD and NPGD rows across the pgd-step settings.
for _eps_idx, _eps in enumerate((2, 4, 8, 16)):
    print('epsilon=%d, different number of pgd steps [2 4 8 16 32]' % _eps)
    print('PGD', PGD_test_ra[_eps_idx])
    print('NPGD', NPGD_test_ra[_eps_idx])

